#!/usr/bin/env python3
"""
简单演示：Auto-Encoding Bayesian Inverse Games

展示MCP博弈求解器（PATH或简单迭代）在3D无人机避碰博弈上的MPC滚动求解
"""

import sys
import os

# Work around the OpenMP runtime conflict: allow the process to continue when
# more than one OpenMP runtime library gets loaded (presumably by different
# native dependencies such as torch and numpy). Set before torch/numpy are
# imported below so the value is in place when the runtime initializes.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

# Make the repository root (parent of this script's directory), its src/ and
# its games/ directories importable, so the project-local `mcp` and
# `drone_game` imports below resolve without installing anything. Each
# insert(0, ...) goes to the front of sys.path, so games/ ends up searched
# first, then src/, then the repository root.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'games'))

import torch
import numpy as np

# 解决Qt平台插件问题 - 使用Agg后端（非交互式）
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from typing import Dict, Any

from mcp import MCPGameSolver
from drone_game import DroneGame


def get_demo_result_path() -> str:
    """Return the absolute path of the demo trajectory figure.

    The figure is stored as ``pathsolverresult.png`` inside
    ``<repo_root>/results/demo_path``. Here ``<repo_root>`` is the parent of the
    directory containing this script. The output directory is created on
    demand, so callers can write to the returned path directly.
    """
    script_dir = os.path.dirname(__file__)
    root = os.path.abspath(os.path.join(script_dir, '..'))
    figure_dir = os.path.join(root, 'results', 'demo_path')
    # Idempotent: does nothing if the directory already exists.
    os.makedirs(figure_dir, exist_ok=True)
    figure_name = 'pathsolverresult.png'
    return os.path.join(figure_dir, figure_name)


def _create_drone_game():
    """Create the 3D two-drone collision-avoidance game used by the demo."""
    return DroneGame(
        dt=0.1,  # small time step for finer-grained control
        collision_radius=2.0,  # enlarged collision radius for a safety margin
        control_limits={
            'theta': 0.5,   # pitch-angle limit
            'phi': 0.5,     # roll-angle limit
            'tau': 20.0     # thrust limit
        },
        velocity_limit=6.0,  # reduced velocity limit for safety
        altitude_limits=(15.0, 55.0),  # allowed altitude band: 15-55 m
        g=9.81
    )


def _print_game_settings(game) -> None:
    """Print the game's dimensions, start/goal positions and physical limits."""
    print("博弈设置:")
    print(f"  玩家数: {game.n_players}")
    print(f"  状态维度: {game.n_states} (3D: px, py, pz, vx, vy, vz × 2)")
    print(f"  控制维度: {game.n_controls} (theta, phi, tau × 2)")
    print(f"  无人机1: 从 {game.initial_positions[0].tolist()} 到 {game.target_positions[0].tolist()}")
    print(f"  无人机2: 从 {game.initial_positions[1].tolist()} 到 {game.target_positions[1].tolist()}")
    print(f"  碰撞半径: {game.collision_radius} m (安全距离: {game.collision_radius * 1.2:.1f} m)")
    print(f"  高度限制: {game.altitude_limits[0]:.1f} ~ {game.altitude_limits[1]:.1f} m")
    print(f"  速度限制: ±{game.velocity_limit} m/s")
    print(f"  重力加速度: {game.g} m/s^2")


def _select_solver(use_path: bool):
    """Return ``(solver_type, solver_params)`` for the PATH or the simple iterative solver."""
    if use_path:
        print("\n使用PATH求解器")
        return "path", {
            'tolerance': 1e-4,  # tighter accuracy requirement
            'verbose': True,
            'max_iterations': 100000,
            'major_iteration_limit': 1000,  # raised major-iteration limit
            'minor_iteration_limit': 50000,  # raised minor-iteration limit
            'time_limit': 1200.0  # 20-minute time limit
        }
    print("\n使用简单迭代求解器")
    return "simple", {
        'learning_rate': 0.01,
        'max_iterations': 10000,  # more iterations
        'tolerance': 1e-2,  # relaxed tolerance
        'verbose': True
    }


def _as_float(value) -> float:
    """Map a missing (``None``) diagnostic to NaN so it can be formatted with ``:e``."""
    return float('nan') if value is None else value


def _run_mpc_loop(game, game_solver, initial_state, n_mpc_steps, control_horizon, prediction_horizon):
    """Run the receding-horizon loop; return ``(states, controls, iteration_results)``.

    Each iteration solves the game from the current state and appends the first
    ``control_horizon`` planned steps to the executed trajectory. The loop stops
    early if a solve returns no trajectory or one too short to execute.
    """
    current_state = initial_state.clone()
    executed_states = [current_state.clone()]
    executed_controls = []
    iteration_results = []

    for mpc_idx in range(n_mpc_steps):
        print(f"\n--- MPC迭代 {mpc_idx + 1}/{n_mpc_steps} ---")
        iter_result = game_solver.solve_game(initial_state=current_state)
        iteration_results.append(iter_result)

        # Use .get throughout: a solver result missing a diagnostic should not
        # abort the whole demo with a KeyError.
        residual = _as_float(iter_result.get('residual'))
        status = iter_result.get('status', 'unknown')
        print(f"  成功: {iter_result.get('success', False)} | 状态: {status} | 残差: {residual:.6e}")
        if 'stationarity' in iter_result:
            print(f"  静态性: {_as_float(iter_result.get('stationarity')):.6e}")
            print(f"  可行性: {_as_float(iter_result.get('feasibility')):.6e}")
            print(f"  互补性: {_as_float(iter_result.get('complementarity')):.6e}")

        states_plan = iter_result.get('states')
        controls_plan = iter_result.get('controls')
        if states_plan is None or controls_plan is None:
            print("  警告: 求解结果缺少轨迹，停止MPC循环。")
            break

        print(f"  计划轨迹: states={states_plan.shape}, controls={controls_plan.shape}")
        # states_plan[k + 1] is the state reached after applying controls_plan[k].
        exec_steps = min(control_horizon, states_plan.shape[0] - 1, controls_plan.shape[0])
        if exec_steps <= 0:
            print("  警告: 规划结果不足以执行任何一步，提前终止。")
            break

        print(f"  执行 {exec_steps} 步控制 (预测{prediction_horizon}步 → 执行{control_horizon}步)")
        for k in range(exec_steps):
            executed_controls.append(controls_plan[k].detach().clone())
            executed_states.append(states_plan[k + 1].detach().clone())

        # NOTE: the next state is taken from the plan itself (no separate
        # dynamics simulation), i.e. a perfect-model assumption.
        current_state = executed_states[-1]

    total_executed = len(executed_controls)
    print(f"\nMPC执行摘要: 规划 {len(iteration_results)} 次，实际执行 {total_executed} 步 ({total_executed * game.dt:.2f} 秒)")

    states = torch.stack(executed_states, dim=0)
    if executed_controls:
        controls = torch.stack(executed_controls, dim=0)
    else:
        controls = torch.zeros((0, game.n_controls), dtype=states.dtype, device=states.device)
    return states, controls, iteration_results


def _report_safety(game, states):
    """Print inter-drone distance statistics; return ``(min_distance, n_collision_steps)``."""
    # Drone 1 position is state[0:3], drone 2 position is state[6:9].
    distances = torch.norm(states[:, 0:3] - states[:, 6:9], dim=1)
    min_distance = distances.min().item()
    n_collision_steps = int((distances < game.collision_radius).sum().item())

    print("\n安全性检查")
    print(f"  最小距离: {min_distance:.3f} m")
    print(f"  碰撞半径: {game.collision_radius} m")
    print(f"  碰撞时间步数: {n_collision_steps}")
    print(f"  是否安全: {'是' if min_distance >= game.collision_radius else '否'}")
    return min_distance, n_collision_steps


def _report_target_distances(game, states) -> None:
    """Print each drone's final position and its distance to its target."""
    final_state = states[-1]
    final_pos1 = final_state[0:3]
    final_pos2 = final_state[6:9]

    dist_to_target1 = torch.norm(final_pos1 - game.target_positions[0]).item()
    dist_to_target2 = torch.norm(final_pos2 - game.target_positions[1]).item()

    print("\n目标达成:")
    print(f"  无人机1到目标距离: {dist_to_target1:.3f} m")
    print(f"  无人机2到目标距离: {dist_to_target2:.3f} m")
    print(f"  无人机1最终位置: {final_pos1.tolist()}")
    print(f"  无人机2最终位置: {final_pos2.tolist()}")


def _mcp_position_cost(pos, target):
    """Quadratic position cost: weight 0.01 on XY, Z additionally weighted 5x."""
    # Intended to mirror the weights of src/mcp/mcp_solver.py::_casadi_cost
    # (per the original comment; that file is not visible here).
    diff = pos - target.to(pos.device)
    return 0.01 * (torch.sum(diff[0:2] ** 2) + 5.0 * torch.sum(diff[2:3] ** 2))


def _report_costs(game, states, executed_horizon: int) -> None:
    """Print the accumulated position cost of each drone over the executed steps."""
    # NOTE(review): sums states[0 .. executed_horizon-1], i.e. includes the
    # initial state and excludes the final one. Confirm this matches the
    # solver's stage-cost indexing.
    total_cost_1 = sum(
        (_mcp_position_cost(states[t, 0:3], game.target_positions[0]).item()
         for t in range(executed_horizon)),
        0.0,
    )
    total_cost_2 = sum(
        (_mcp_position_cost(states[t, 6:9], game.target_positions[1]).item()
         for t in range(executed_horizon)),
        0.0,
    )

    print("\n成本分析:")
    print(f"  无人机1总成本: {total_cost_1:.3f}")
    print(f"  无人机2总成本: {total_cost_2:.3f}")


def demo_game_solver(use_path: bool = True):  # PATH is used by default
    """Run the 3D drone collision-avoidance game in a receding-horizon (MPC) loop.

    Each MPC iteration solves the game over ``prediction_horizon`` steps and
    executes the first ``control_horizon`` of them. The executed trajectory is
    then checked for collisions, distance to targets and position cost. It is
    plotted to the path returned by ``get_demo_result_path()``.

    Args:
        use_path: If True, use the PATH solver (requires a Julia environment).
            Otherwise use the simple iterative solver.

    Returns:
        A tuple ``(game_solver, summary)``. ``summary`` is a dict with the keys
        ``success``, ``status``, ``residual``, ``states``, ``controls``,
        ``iterations``, ``executed_steps`` and ``iteration_results``. It is
        ``None`` if an exception was raised after the solver was created; that
        exception is printed, not re-raised. Exceptions while building the game
        or the solver propagate to the caller.
    """
    print("\n=== 3D无人机避碰博弈演示 ===")

    game = _create_drone_game()
    _print_game_settings(game)

    # Shared MPC configuration: predict 10 steps, execute 1. The loop count
    # matches the MPGP demo.
    control_horizon = 1      # steps executed per MPC iteration
    prediction_horizon = 10  # planning horizon per solve
    n_mpc_steps = 50         # number of MPC iterations (same as MPGP_demo.py)

    solver_type, solver_params = _select_solver(use_path)

    # The solver horizon equals the MPC prediction window.
    game_solver = MCPGameSolver(
        game=game,
        horizon=prediction_horizon,
        solver_type=solver_type,
        solver_params=solver_params
    )

    print("\n求解器设置:")
    print(f"  求解器类型: {solver_type}")
    print(f"  时间步数: {game_solver.horizon}")
    print(f"  决策变量数: {game_solver.n_vars}")
    print(f"  约束数: {game_solver.n_constraints}")
    print(f"  MCP变量总数: {game_solver.n_mcp_vars}")
    print(f"  MPC预测窗口: {prediction_horizon} 步 ({prediction_horizon * game.dt:.2f} 秒)")
    print(f"  执行步数: {control_horizon} 步 (每次滚动执行1步)")
    print(f"  MPC循环次数: {n_mpc_steps} (保持与MPGP demo一致)")

    initial_state = game.get_initial_state()
    print(f"\n初始状态: {initial_state.tolist()}")

    try:
        print("\n开始求解博弈（MPC滚动执行）...")
        states, controls, iteration_results = _run_mpc_loop(
            game, game_solver, initial_state,
            n_mpc_steps, control_horizon, prediction_horizon,
        )

        # Success means every planned iteration ran and each solve succeeded.
        overall_success = (
            len(iteration_results) == n_mpc_steps
            and all(r.get('success', False) for r in iteration_results)
        )

        _report_safety(game, states)
        _report_target_distances(game, states)
        executed_steps = controls.shape[0]
        _report_costs(game, states, executed_steps)

        save_path = get_demo_result_path()
        game.visualize_trajectory(states, save_path=save_path)
        print(f"  轨迹图保存路径: {save_path}")

        last_result = iteration_results[-1] if iteration_results else None
        summary = {
            'success': overall_success,
            'status': last_result.get('status', 'not_started') if last_result is not None else 'not_started',
            'residual': last_result.get('residual', float('nan')) if last_result is not None else float('nan'),
            'states': states,
            'controls': controls,
            'iterations': len(iteration_results),
            'executed_steps': executed_steps,
            'iteration_results': iteration_results
        }

        return game_solver, summary

    except Exception as e:
        # Demo boundary: report the failure and let the caller decide what to do.
        print(f"博弈求解出错: {e}")
        import traceback
        traceback.print_exc()
        return game_solver, None


def main():
    """Command-line entry point for the demo.

    Runs ``demo_game_solver`` with the PATH solver by default. Passing
    ``--simple`` anywhere on the command line selects the simple iterative
    solver instead. Prints follow-up hints depending on the outcome.
    """
    print("=" * 60)
    print("Auto-Encoding Bayesian Inverse Games - 无人机避碰演示")
    print("=" * 60)

    # Recognize --simple in any argument position, not only as argv[1].
    use_path = '--simple' not in sys.argv[1:]
    if use_path:
        print("\n使用PATH求解器（需要Julia和PATHSolver.jl）")
        print("提示: 使用 '--simple' 参数可切换到简单迭代求解器")
    else:
        print("\n使用简单迭代求解器")

    try:
        _, result = demo_game_solver(use_path=use_path)

        print("\n" + "=" * 60)
        if result is not None and result['success']:
            print("演示完成! 🎉")
            print("生成的文件:")
            print(f"- 轨迹图: {get_demo_result_path()}")
        else:
            # demo_game_solver returns result=None when it caught an exception.
            # That is a failure, not a convergence problem, so say so.
            if result is None:
                print("演示失败: 求解过程中出错（详见上方错误信息）")
            else:
                print("演示完成，但求解未收敛")
            print("\n建议:")
            if use_path:
                print("1. 检查Julia和PATHSolver.jl是否正确安装")
                print("2. 调整博弈参数（时间步长、碰撞半径等）")
                print("3. 尝试简单求解器: python simple_demo.py --simple")
            else:
                print("1. 增加迭代次数")
                print("2. 调整学习率或容差")
                print("3. 使用PATH求解器: python simple_demo.py")

    except Exception as e:
        # Errors raised while building the game/solver reach here. Errors during
        # solving are caught inside demo_game_solver.
        print(f"\n演示过程中出错: {e}")
        import traceback
        traceback.print_exc()

        if use_path and "Julia" in str(e):
            print("\n看起来Julia或PATH求解器未正确配置。")
            print("请参考 docs/PATH_SETUP.md 进行配置。")
            print("或者使用简单求解器: python simple_demo.py --simple")


# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
